{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient Boosting Regression from Scratch\n",
    "\n",
    "The Gradient Boosting (GB) algorithm trains a series of weak learners and each focuses on the errors the previous learners have made and tries to improve it. Together, they make a better prediction.\n",
    "\n",
    "According to Wikipedia, Gradient boosting is a machine learning technique for regression and classification problems, which produces a prediction model in the form of an ensemble of weak prediction models, typically decision trees. It builds the model in a stage-wise fashion as other boosting methods do, and it generalizes them by allowing optimization of an arbitrary differentiable loss function. \n",
    "\n",
    "Prerequisite\n",
    "\n",
    "    1. Linear regression and gradient descent\n",
    "    2. Decision Tree\n",
    "\n",
    "After studying this post, you will be able to:\n",
    "\n",
    "    1. Explain gradient boosting algorithm.\n",
    "    2. Explain gradient boosting regression algorithm.\n",
    "    3. Write a gradient boosting regressor from scratch\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "    \n",
    "# The algorithm\n",
    "\n",
    "The following plot illustrates the algorithm.\n",
    "\n",
    "![Gradient Boosting Regression](gradient_boosting/gradient_boosting_regression.png)\n",
    "\n",
    "From the plot above, the first part is a stump, which is the average of **y**. We then add several trees to it. In the following trees, the target is not y. Instead, the target is the residual or the true value subtracts the previous prediction.\n",
    "\n",
    "$$residual=true\\_value - previous\\_prediction$$\n",
    "\n",
    "That is why we say in Gradient Boosting trains a series of weak learners, each focuses on the errors of the previous one. The residual predictions are multiplied by the learning rate (0.1 here) before added to the average.\n",
    "\n",
    "---\n",
    "\n",
    "**The Steps**\n",
    "\n",
    "Step 1: Calculate the average of y. The average is also the first estimation of y:\n",
    "$$\\bar{y}=\\frac{1}{n} \\sum_{i=1}^{n}y_i$$\n",
    "\n",
    "$$F_0(x)=\\bar{y}$$\n",
    "Step 2 for m in 1 to M: <br />\n",
    "  * Step 2.1: Compute so-call pseudo-residuals:\n",
    "    $$r_{im}=y_i-F_{m-1}(x_i)$$\n",
    "  * Step 2.2: Fit a regression tree $t_m(x)$ to pseudo-residuals and create terminal regions (leafs) $R_{jm}$ for $j=1...Jm$ <br />\n",
    "\n",
    "  * Step 2.3: For each leaf of the tree, there are $p_j$ elements, compute $\\gamma$ as following equation. <br />\n",
    "\n",
    "$$\\gamma_{im}=\\frac{1}{p_j} \\sum_{x_i \\in R_{jm}} r_{im}$$\n",
    "\n",
    "  * (In practise, the regression tree will do this for us.)\n",
    "\n",
    "  * Step 2.4: Update the model with learning rate $\\alpha$:\n",
    "$$F_m(x)=F_{m-1}+\\alpha\\gamma_m$$\n",
    "\n",
    "\n",
    "Step 3. Output $$F_M(x)$$\n",
    "\n",
    "---\n",
    "\n",
    "In practice the regression tree will average the leaf for us. Thus, Step 2.2 and 2.3 can be combined into one step. And the steps can be simplified:\n",
    "\n",
    "---\n",
    "\n",
    "**New The Steps**\n",
    "\n",
    "Step 1: Calculate the average of y. The average is also the first estimation of y:\n",
    "$$\\bar{y}=\\frac{1}{n} \\sum_{i=1}^{n}y_i$$\n",
    "\n",
    "$$F_0(x)=\\bar{y}$$\n",
    "Step 2 for m in 1 to M: <br />\n",
    "  * Step 2.1: Compute so-call pseudo-residuals:\n",
    "    $$r_{im}=y_i-F_{m-1}(x_i)$$\n",
    "  * Step 2.2: Fit a regression tree $t_m(x)$ to pseudo-residuals\n",
    "\n",
    "  * Step 2.3: Update the model with learning rate $\\alpha$:\n",
    "$$F_m(x)=F_{m-1}+\\alpha t_m(x)$$\n",
    "\n",
    "\n",
    "Step 3. Output $$F_M(x)$$\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# (Optional) From Gradient Boosting to Gradient Boosting Regression\n",
    "\n",
    "The above knowledge is enough for writing BGR code from scratch. But I want to explain more about gradient boosting. GB is a meta-algorithm that can be applied to both regression and classification. The above one is only a specific form for regression. In the following, I will introduce the general gradient boosting algorithm and deduce GBR from GB.\n",
    "\n",
    "Let's first look at the GB steps\n",
    "\n",
    "---\n",
    "\n",
    "**The Steps**\n",
    "\n",
    "Input: training set $\\{(x_i, y_i)\\}_{i=1}^{n}$, a differentiable loss function $L(y, F(x))$, number of iterations M\n",
    "\n",
    "Algorithm:\n",
    "\n",
    "Step 1: Initialize model with a constant value:\n",
    "\n",
    "$$F_0(x)=\\underset{\\gamma}{\\operatorname{argmin}}\\sum_{i=1}^{n}L(y_i, \\gamma)$$\n",
    "\n",
    "Step 2 for m in 1 to M: <br />\n",
    "  * Step 2.1: Compute so-call pseudo-residuals:\n",
    "    $$r_{im}=-[\\frac{\\partial L(y_i, F(x_i))}{\\partial F(x_i)}]_{F(x)=F_{m-1}(x)}$$\n",
    "  * Step 2.2: Fit a weak learner $h_m(x)$ to pseudo-residuals. and create terminal regions $R_{jm}$, for $j=1...J_m$<br />\n",
    "\n",
    "  * Step 2.3: For each leaf of the tree, compute $\\gamma$ as the following equation. Here $\\hat{r}$ is the predicted residual produced by $h_m(x)$.<br />\n",
    "$$\\gamma_{jm}=\\underset{\\gamma}{\\operatorname{argmin}}\\sum_{x_i \\in R_{jm}}^{n}L(y_i, F_{m-1}(x_i)+\\gamma)$$\n",
    "\n",
    "  * Step 2.4: Update the model with learning rate $\\alpha$:\n",
    "$$F_m(x)=F_{m-1}+\\alpha\\gamma_m$$\n",
    "\n",
    "\n",
    "Step 3. Output $$F_M(x)$$\n",
    "\n",
    "---\n",
    "\n",
    "To deduce the GB to GBR, I simply define a loss function and solve the loss function in step 1, 2.1 and 2.3. We use sum of squared errror(SSE) as the loss function:\n",
    "\n",
    "$$L(y, \\gamma)=\\frac{1}{2}\\sum_{i=1}^{n}(y_i-\\gamma)^2$$\n",
    "\n",
    "\n",
    "\n",
    "For step 1:\n",
    "\n",
    "Because SSE is a convex and at the lowest point where the derivative is zero, we have the following:\n",
    "\n",
    "$$\\frac{\\partial L(y, F_0)}{\\partial F_0}=\\frac{\\partial \\frac{1}{2}\\sum_{i=1}^{n}(y_i-F_0)^2}{\\partial F_0}\n",
    "=\\sum_{i=1}^{n} (y_i-F_0)=0\n",
    "$$\n",
    "\n",
    "Thus, we have:\n",
    "\n",
    "$$F_0=\\frac{1}{n}\\sum_{i=1}^{n}y_i$$\n",
    "\n",
    "For step 2.1:\n",
    "\n",
    "$$r_{im}=-[\\frac{\\partial L(y_i, F(x_i))}{\\partial F(x_i)}]_{F(x)=F_{m-1}(x)}$$\n",
    "\n",
    "$$=-[\\frac{\\partial \\frac{1}{2}\\sum_{i=1}^{n}(y_i-F_{m-1}(x_i))^2)}{\\partial F_{m-1}(x_i)}]_{F(x)=F_{m-1}(x)}$$\n",
    "\n",
    "(The chain rule)\n",
    "\n",
    "$$=--2*\\frac{1}{2}(y_i-F_{m-1}(x_i))$$\n",
    "$$=y_i-F_{m-1}(x_i)$$\n",
    "\n",
    "For step 2.3:\n",
    "\n",
    "Similarly, the result is:\n",
    "\n",
    "$$\\gamma_{jm}=\\frac{1}{p_j}\\sum_{x_i \\in R_j}r_{mi}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0.5\n",
      "0.23.1\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import sklearn\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from sklearn.datasets import load_boston\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import graphviz \n",
    "from sklearn import tree\n",
    "\n",
    "print(pd.__version__)\n",
    "print(sklearn.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>height</th>\n",
       "      <th>gender</th>\n",
       "      <th>weight</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Alex</td>\n",
       "      <td>1.6</td>\n",
       "      <td>male</td>\n",
       "      <td>88</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brunei</td>\n",
       "      <td>1.6</td>\n",
       "      <td>female</td>\n",
       "      <td>76</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Candy</td>\n",
       "      <td>1.5</td>\n",
       "      <td>female</td>\n",
       "      <td>56</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>David</td>\n",
       "      <td>1.8</td>\n",
       "      <td>male</td>\n",
       "      <td>73</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Eric</td>\n",
       "      <td>1.5</td>\n",
       "      <td>male</td>\n",
       "      <td>77</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Felicity</td>\n",
       "      <td>1.4</td>\n",
       "      <td>female</td>\n",
       "      <td>57</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       name  height  gender  weight\n",
       "0      Alex     1.6    male      88\n",
       "1    Brunei     1.6  female      76\n",
       "2     Candy     1.5  female      56\n",
       "3     David     1.8    male      73\n",
       "4      Eric     1.5    male      77\n",
       "5  Felicity     1.4  female      57"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>height</th>\n",
       "      <th>gender</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.6</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.6</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.5</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.8</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.5</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1.4</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   height gender\n",
       "0     1.6      1\n",
       "1     1.6      0\n",
       "2     1.5      0\n",
       "3     1.8      1\n",
       "4     1.5      1\n",
       "5     1.4      0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "df=pd.DataFrame()\n",
    "df['name']=['Alex','Brunei','Candy','David','Eric','Felicity']\n",
    "df['height']=[1.6,1.6,1.5,1.8,1.5,1.4]\n",
    "df['gender']=['male','female','female','male','male','female']\n",
    "df['weight']=[88, 76, 56, 73, 77, 57]\n",
    "display(df)\n",
    "\n",
    "X=df[['height','gender']].copy()\n",
    "X.loc[X['gender']=='male','gender']=1\n",
    "X.loc[X['gender']=='female','gender']=0\n",
    "y=df['weight']\n",
    "display(X)\n",
    "\n",
    "n=df.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 Average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>height</th>\n",
       "      <th>gender</th>\n",
       "      <th>weight</th>\n",
       "      <th>$f_0$</th>\n",
       "      <th>$r_0$</th>\n",
       "      <th>$\\gamma_1$</th>\n",
       "      <th>$f_1$</th>\n",
       "      <th>$r_1$</th>\n",
       "      <th>$\\gamma_2$</th>\n",
       "      <th>...</th>\n",
       "      <th>$r_2$</th>\n",
       "      <th>$\\gamma_3$</th>\n",
       "      <th>$f_3$</th>\n",
       "      <th>$r_3$</th>\n",
       "      <th>$\\gamma_4$</th>\n",
       "      <th>$f_4$</th>\n",
       "      <th>$r_4$</th>\n",
       "      <th>$\\gamma_5$</th>\n",
       "      <th>$f_5$</th>\n",
       "      <th>$r_5$</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Alex</td>\n",
       "      <td>1.6</td>\n",
       "      <td>male</td>\n",
       "      <td>88</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>16.833333</td>\n",
       "      <td>8.166667</td>\n",
       "      <td>72.800000</td>\n",
       "      <td>15.200000</td>\n",
       "      <td>7.288889</td>\n",
       "      <td>...</td>\n",
       "      <td>13.742222</td>\n",
       "      <td>6.047407</td>\n",
       "      <td>75.467259</td>\n",
       "      <td>12.532741</td>\n",
       "      <td>5.427951</td>\n",
       "      <td>76.552849</td>\n",
       "      <td>11.447151</td>\n",
       "      <td>4.476063</td>\n",
       "      <td>77.448062</td>\n",
       "      <td>10.551938</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brunei</td>\n",
       "      <td>1.6</td>\n",
       "      <td>female</td>\n",
       "      <td>76</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>4.833333</td>\n",
       "      <td>-8.166667</td>\n",
       "      <td>69.533333</td>\n",
       "      <td>6.466667</td>\n",
       "      <td>7.288889</td>\n",
       "      <td>...</td>\n",
       "      <td>5.008889</td>\n",
       "      <td>-6.047407</td>\n",
       "      <td>69.781630</td>\n",
       "      <td>6.218370</td>\n",
       "      <td>5.427951</td>\n",
       "      <td>70.867220</td>\n",
       "      <td>5.132780</td>\n",
       "      <td>-4.476063</td>\n",
       "      <td>69.972007</td>\n",
       "      <td>6.027993</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Candy</td>\n",
       "      <td>1.5</td>\n",
       "      <td>female</td>\n",
       "      <td>56</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>-15.166667</td>\n",
       "      <td>-8.166667</td>\n",
       "      <td>69.533333</td>\n",
       "      <td>-13.533333</td>\n",
       "      <td>-7.288889</td>\n",
       "      <td>...</td>\n",
       "      <td>-12.075556</td>\n",
       "      <td>-6.047407</td>\n",
       "      <td>66.866074</td>\n",
       "      <td>-10.866074</td>\n",
       "      <td>-5.427951</td>\n",
       "      <td>65.780484</td>\n",
       "      <td>-9.780484</td>\n",
       "      <td>-4.476063</td>\n",
       "      <td>64.885271</td>\n",
       "      <td>-8.885271</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>David</td>\n",
       "      <td>1.8</td>\n",
       "      <td>male</td>\n",
       "      <td>73</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>1.833333</td>\n",
       "      <td>8.166667</td>\n",
       "      <td>72.800000</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>7.288889</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.257778</td>\n",
       "      <td>6.047407</td>\n",
       "      <td>75.467259</td>\n",
       "      <td>-2.467259</td>\n",
       "      <td>5.427951</td>\n",
       "      <td>76.552849</td>\n",
       "      <td>-3.552849</td>\n",
       "      <td>4.476063</td>\n",
       "      <td>77.448062</td>\n",
       "      <td>-4.448062</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Eric</td>\n",
       "      <td>1.5</td>\n",
       "      <td>male</td>\n",
       "      <td>77</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>5.833333</td>\n",
       "      <td>8.166667</td>\n",
       "      <td>72.800000</td>\n",
       "      <td>4.200000</td>\n",
       "      <td>-7.288889</td>\n",
       "      <td>...</td>\n",
       "      <td>5.657778</td>\n",
       "      <td>6.047407</td>\n",
       "      <td>72.551704</td>\n",
       "      <td>4.448296</td>\n",
       "      <td>-5.427951</td>\n",
       "      <td>71.466114</td>\n",
       "      <td>5.533886</td>\n",
       "      <td>4.476063</td>\n",
       "      <td>72.361326</td>\n",
       "      <td>4.638674</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Felicity</td>\n",
       "      <td>1.4</td>\n",
       "      <td>female</td>\n",
       "      <td>57</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>-14.166667</td>\n",
       "      <td>-8.166667</td>\n",
       "      <td>69.533333</td>\n",
       "      <td>-12.533333</td>\n",
       "      <td>-7.288889</td>\n",
       "      <td>...</td>\n",
       "      <td>-11.075556</td>\n",
       "      <td>-6.047407</td>\n",
       "      <td>66.866074</td>\n",
       "      <td>-9.866074</td>\n",
       "      <td>-5.427951</td>\n",
       "      <td>65.780484</td>\n",
       "      <td>-8.780484</td>\n",
       "      <td>-4.476063</td>\n",
       "      <td>64.885271</td>\n",
       "      <td>-7.885271</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>6 rows × 21 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       name  height  gender  weight      $f_0$      $r_0$  $\\gamma_1$  \\\n",
       "0      Alex     1.6    male      88  71.166667  16.833333    8.166667   \n",
       "1    Brunei     1.6  female      76  71.166667   4.833333   -8.166667   \n",
       "2     Candy     1.5  female      56  71.166667 -15.166667   -8.166667   \n",
       "3     David     1.8    male      73  71.166667   1.833333    8.166667   \n",
       "4      Eric     1.5    male      77  71.166667   5.833333    8.166667   \n",
       "5  Felicity     1.4  female      57  71.166667 -14.166667   -8.166667   \n",
       "\n",
       "       $f_1$      $r_1$  $\\gamma_2$  ...      $r_2$  $\\gamma_3$      $f_3$  \\\n",
       "0  72.800000  15.200000    7.288889  ...  13.742222    6.047407  75.467259   \n",
       "1  69.533333   6.466667    7.288889  ...   5.008889   -6.047407  69.781630   \n",
       "2  69.533333 -13.533333   -7.288889  ... -12.075556   -6.047407  66.866074   \n",
       "3  72.800000   0.200000    7.288889  ...  -1.257778    6.047407  75.467259   \n",
       "4  72.800000   4.200000   -7.288889  ...   5.657778    6.047407  72.551704   \n",
       "5  69.533333 -12.533333   -7.288889  ... -11.075556   -6.047407  66.866074   \n",
       "\n",
       "       $r_3$  $\\gamma_4$      $f_4$      $r_4$  $\\gamma_5$      $f_5$  \\\n",
       "0  12.532741    5.427951  76.552849  11.447151    4.476063  77.448062   \n",
       "1   6.218370    5.427951  70.867220   5.132780   -4.476063  69.972007   \n",
       "2 -10.866074   -5.427951  65.780484  -9.780484   -4.476063  64.885271   \n",
       "3  -2.467259    5.427951  76.552849  -3.552849    4.476063  77.448062   \n",
       "4   4.448296   -5.427951  71.466114   5.533886    4.476063  72.361326   \n",
       "5  -9.866074   -5.427951  65.780484  -8.780484   -4.476063  64.885271   \n",
       "\n",
       "       $r_5$  \n",
       "0  10.551938  \n",
       "1   6.027993  \n",
       "2  -8.885271  \n",
       "3  -4.448062  \n",
       "4   4.638674  \n",
       "5  -7.885271  \n",
       "\n",
       "[6 rows x 21 columns]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#now let's get started\n",
    "learning_rate=0.2\n",
    "loss = [0] * 6\n",
    "residuals = np.zeros([6,n])\n",
    "predictoin = np.zeros([6,n])\n",
    "#calculation\n",
    "average_y=y.mean()\n",
    "predictoin[0] = [average_y] * n\n",
    "residuals[0] = y - predictoin[0]\n",
    "df['$f_0$']=predictoin[0]\n",
    "df['$r_0$']=residuals[0]\n",
    "display(df)\n",
    "loss[0] = np.sum(residuals[0] ** 2)/n\n",
    "trees = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the first step, we calculate the average 71.2 as the initial prediction. The pseudo residuals are 16.8, 4.8, etc."
   ]
  },
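  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick hand check of these numbers, here is a worked calculation using the six weights from the table above, rounded to three decimals:\n",
    "\n",
    "$$\\bar{y}=\\frac{88+76+56+73+77+57}{6}=\\frac{427}{6}\\approx 71.167$$\n",
    "\n",
    "$$r_{\\text{Alex}}=88-71.167=16.833, \\qquad r_{\\text{Brunei}}=76-71.167=4.833$$"
   ]
  },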
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 For Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iterate(i):\n",
    "    t = DecisionTreeRegressor(max_depth=1)\n",
    "    t.fit(X,residuals[i])\n",
    "    trees.append(t)\n",
    "    #next prediction, residual\n",
    "    predictoin[i+1]=predictoin[i]+learning_rate * t.predict(X)\n",
    "    residuals[i+1]=y-predictoin[i+1]\n",
    "    loss[i+1] = np.sum(residuals[i+1] ** 2)/n\n",
    "    \n",
    "    df[f'$\\gamma_{i+1}$']=t.predict(X)\n",
    "    df[f'$f_{i+1}$']=predictoin[i+1]\n",
    "    df[f'$r_{i+1}$']=residuals[i+1]\n",
    "    \n",
    "    display(df[['name','height','gender','weight',f'$f_{i}$',f'$r_{i}$',f'$\\gamma_{i+1}$',f'$f_{i+1}$',f'$r_{i+1}$']])\n",
    "    \n",
    "    dot_data = tree.export_graphviz(t, out_file=None, filled=True, rounded=True,feature_names=X.columns) \n",
    "    graph = graphviz.Source(dot_data) \n",
    "    display(graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>height</th>\n",
       "      <th>gender</th>\n",
       "      <th>weight</th>\n",
       "      <th>$f_0$</th>\n",
       "      <th>$r_0$</th>\n",
       "      <th>$\\gamma_1$</th>\n",
       "      <th>$f_1$</th>\n",
       "      <th>$r_1$</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Alex</td>\n",
       "      <td>1.6</td>\n",
       "      <td>male</td>\n",
       "      <td>88</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>16.833333</td>\n",
       "      <td>8.166667</td>\n",
       "      <td>72.800000</td>\n",
       "      <td>15.200000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brunei</td>\n",
       "      <td>1.6</td>\n",
       "      <td>female</td>\n",
       "      <td>76</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>4.833333</td>\n",
       "      <td>-8.166667</td>\n",
       "      <td>69.533333</td>\n",
       "      <td>6.466667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Candy</td>\n",
       "      <td>1.5</td>\n",
       "      <td>female</td>\n",
       "      <td>56</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>-15.166667</td>\n",
       "      <td>-8.166667</td>\n",
       "      <td>69.533333</td>\n",
       "      <td>-13.533333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>David</td>\n",
       "      <td>1.8</td>\n",
       "      <td>male</td>\n",
       "      <td>73</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>1.833333</td>\n",
       "      <td>8.166667</td>\n",
       "      <td>72.800000</td>\n",
       "      <td>0.200000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Eric</td>\n",
       "      <td>1.5</td>\n",
       "      <td>male</td>\n",
       "      <td>77</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>5.833333</td>\n",
       "      <td>8.166667</td>\n",
       "      <td>72.800000</td>\n",
       "      <td>4.200000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Felicity</td>\n",
       "      <td>1.4</td>\n",
       "      <td>female</td>\n",
       "      <td>57</td>\n",
       "      <td>71.166667</td>\n",
       "      <td>-14.166667</td>\n",
       "      <td>-8.166667</td>\n",
       "      <td>69.533333</td>\n",
       "      <td>-12.533333</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       name  height  gender  weight      $f_0$      $r_0$  $\\gamma_1$  \\\n",
       "0      Alex     1.6    male      88  71.166667  16.833333    8.166667   \n",
       "1    Brunei     1.6  female      76  71.166667   4.833333   -8.166667   \n",
       "2     Candy     1.5  female      56  71.166667 -15.166667   -8.166667   \n",
       "3     David     1.8    male      73  71.166667   1.833333    8.166667   \n",
       "4      Eric     1.5    male      77  71.166667   5.833333    8.166667   \n",
       "5  Felicity     1.4  female      57  71.166667 -14.166667   -8.166667   \n",
       "\n",
       "       $f_1$      $r_1$  \n",
       "0  72.800000  15.200000  \n",
       "1  69.533333   6.466667  \n",
       "2  69.533333 -13.533333  \n",
       "3  72.800000   0.200000  \n",
       "4  72.800000   4.200000  \n",
       "5  69.533333 -12.533333  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"238pt\" height=\"165pt\"\r\n",
       " viewBox=\"0.00 0.00 238.00 165.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 161)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-161 234,-161 234,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<path fill=\"#f2c09c\" stroke=\"black\" d=\"M160,-157C160,-157 71,-157 71,-157 65,-157 59,-151 59,-145 59,-145 59,-101 59,-101 59,-95 65,-89 71,-89 71,-89 160,-89 160,-89 166,-89 172,-95 172,-101 172,-101 172,-145 172,-145 172,-151 166,-157 160,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gender &lt;= 0.5</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 129.139</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-96.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = &#45;0.0</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<path fill=\"#ffffff\" stroke=\"black\" d=\"M95,-53C95,-53 12,-53 12,-53 6,-53 0,-47 0,-41 0,-41 0,-12 0,-12 0,-6 6,-0 12,-0 12,-0 95,-0 95,-0 101,-0 107,-6 107,-12 107,-12 107,-41 107,-41 107,-47 101,-53 95,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 84.667</text>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = &#45;8.167</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M93.8154,-88.9485C88.0536,-80.1664 81.8185,-70.6629 76.0136,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"78.7975,-59.6779 70.3855,-53.2367 72.9447,-63.5178 78.7975,-59.6779\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"65.3022\" y=\"-74.0145\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<path fill=\"#e58139\" stroke=\"black\" d=\"M218,-53C218,-53 137,-53 137,-53 131,-53 125,-47 125,-41 125,-41 125,-12 125,-12 125,-6 131,-0 137,-0 137,-0 218,-0 218,-0 224,-0 230,-6 230,-12 230,-12 230,-41 230,-41 230,-47 224,-53 218,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 40.222</text>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 8.167</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M137.185,-88.9485C142.946,-80.1664 149.181,-70.6629 154.986,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"158.055,-63.5178 160.615,-53.2367 152.203,-59.6779 158.055,-63.5178\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"165.698\" y=\"-74.0145\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x2858a2fefa0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "iterate(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In iteration 0, we first train a tree using residuals_0. This tree tells us that males are heavier than females: each male's prediction should increase by 8.167 kg and each female's should decrease by 8.167 kg. The output of the tree is $\\gamma$ here. However, we want to take just a small step at a time, so we multiply it by the learning rate $\\alpha=0.2$. Thus, the new prediction is $prediction=prediction+\\alpha\\gamma$. That is to say, each male gains $8.167\\times 0.2=1.6334$ kg and each female changes by $-8.167\\times 0.2=-1.6334$ kg, i.e. loses 1.6334 kg. Finally, the males are predicted to weigh 72.8 kg and the females 69.5 kg."
   ]
  },
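  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked check of this update, the arithmetic for iteration 0 can be written out. This is only a sketch under two assumptions: the initial prediction $f_0$ is the mean of the six weights in the table ($427/6 \\approx 71.167$), and the subscripts on $f$ and $\\gamma$ follow the column names used in the output tables.\n",
    "\n",
    "$$f_1 = f_0 + \\alpha\\gamma_1$$\n",
    "\n",
    "$$f_1^{male} \\approx 71.167 + 0.2 \\times 8.167 \\approx 72.800$$\n",
    "\n",
    "$$f_1^{female} \\approx 71.167 + 0.2 \\times (-8.167) \\approx 69.533$$\n",
    "\n",
    "These values agree with the $f_1$ column of the table printed by the next iteration."
   ]
  },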
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>height</th>\n",
       "      <th>gender</th>\n",
       "      <th>weight</th>\n",
       "      <th>$f_1$</th>\n",
       "      <th>$r_1$</th>\n",
       "      <th>$\\gamma_2$</th>\n",
       "      <th>$f_2$</th>\n",
       "      <th>$r_2$</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Alex</td>\n",
       "      <td>1.6</td>\n",
       "      <td>male</td>\n",
       "      <td>88</td>\n",
       "      <td>72.800000</td>\n",
       "      <td>15.200000</td>\n",
       "      <td>7.288889</td>\n",
       "      <td>74.257778</td>\n",
       "      <td>13.742222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brunei</td>\n",
       "      <td>1.6</td>\n",
       "      <td>female</td>\n",
       "      <td>76</td>\n",
       "      <td>69.533333</td>\n",
       "      <td>6.466667</td>\n",
       "      <td>7.288889</td>\n",
       "      <td>70.991111</td>\n",
       "      <td>5.008889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Candy</td>\n",
       "      <td>1.5</td>\n",
       "      <td>female</td>\n",
       "      <td>56</td>\n",
       "      <td>69.533333</td>\n",
       "      <td>-13.533333</td>\n",
       "      <td>-7.288889</td>\n",
       "      <td>68.075556</td>\n",
       "      <td>-12.075556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>David</td>\n",
       "      <td>1.8</td>\n",
       "      <td>male</td>\n",
       "      <td>73</td>\n",
       "      <td>72.800000</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>7.288889</td>\n",
       "      <td>74.257778</td>\n",
       "      <td>-1.257778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Eric</td>\n",
       "      <td>1.5</td>\n",
       "      <td>male</td>\n",
       "      <td>77</td>\n",
       "      <td>72.800000</td>\n",
       "      <td>4.200000</td>\n",
       "      <td>-7.288889</td>\n",
       "      <td>71.342222</td>\n",
       "      <td>5.657778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Felicity</td>\n",
       "      <td>1.4</td>\n",
       "      <td>female</td>\n",
       "      <td>57</td>\n",
       "      <td>69.533333</td>\n",
       "      <td>-12.533333</td>\n",
       "      <td>-7.288889</td>\n",
       "      <td>68.075556</td>\n",
       "      <td>-11.075556</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       name  height  gender  weight      $f_1$      $r_1$  $\\gamma_2$  \\\n",
       "0      Alex     1.6    male      88  72.800000  15.200000    7.288889   \n",
       "1    Brunei     1.6  female      76  69.533333   6.466667    7.288889   \n",
       "2     Candy     1.5  female      56  69.533333 -13.533333   -7.288889   \n",
       "3     David     1.8    male      73  72.800000   0.200000    7.288889   \n",
       "4      Eric     1.5    male      77  72.800000   4.200000   -7.288889   \n",
       "5  Felicity     1.4  female      57  69.533333 -12.533333   -7.288889   \n",
       "\n",
       "       $f_2$      $r_2$  \n",
       "0  74.257778  13.742222  \n",
       "1  70.991111   5.008889  \n",
       "2  68.075556 -12.075556  \n",
       "3  74.257778  -1.257778  \n",
       "4  71.342222   5.657778  \n",
       "5  68.075556 -11.075556  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"238pt\" height=\"165pt\"\r\n",
       " viewBox=\"0.00 0.00 238.00 165.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 161)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-161 234,-161 234,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<path fill=\"#f2c09c\" stroke=\"black\" d=\"M160,-157C160,-157 71,-157 71,-157 65,-157 59,-151 59,-145 59,-145 59,-101 59,-101 59,-95 65,-89 71,-89 71,-89 160,-89 160,-89 166,-89 172,-95 172,-101 172,-101 172,-145 172,-145 172,-151 166,-157 160,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">height &lt;= 1.55</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 105.129</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-96.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.0</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<path fill=\"#ffffff\" stroke=\"black\" d=\"M95,-53C95,-53 12,-53 12,-53 6,-53 0,-47 0,-41 0,-41 0,-12 0,-12 0,-6 6,-0 12,-0 12,-0 95,-0 95,-0 101,-0 107,-6 107,-12 107,-12 107,-41 107,-41 107,-47 101,-53 95,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 66.164</text>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = &#45;7.289</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M93.8154,-88.9485C88.0536,-80.1664 81.8185,-70.6629 76.0136,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"78.7975,-59.6779 70.3855,-53.2367 72.9447,-63.5178 78.7975,-59.6779\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"65.3022\" y=\"-74.0145\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<path fill=\"#e58139\" stroke=\"black\" d=\"M218,-53C218,-53 137,-53 137,-53 131,-53 125,-47 125,-41 125,-41 125,-12 125,-12 125,-6 131,-0 137,-0 137,-0 218,-0 218,-0 224,-0 230,-6 230,-12 230,-12 230,-41 230,-41 230,-47 224,-53 218,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 37.838</text>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 7.289</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M137.185,-88.9485C142.946,-80.1664 149.181,-70.6629 154.986,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"158.055,-63.5178 160.615,-53.2367 152.203,-59.6779 158.055,-63.5178\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"165.698\" y=\"-74.0145\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x2858c369340>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "iterate(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In iteration 1, we first train a tree using residuals_1. This tree tells us that height is also important in determining weight. Those who are 1.55 meters or shorter are supposed to lose 7.289 kg and the others to gain 7.289 kg. Again, we shrink this to 20%, which gives -1.4578 and 1.4578. We then compute prediction_2 from prediction_1 and $\\gamma$. Alex gains 1.4578 kg because he is 1.6 meters tall. The others also gain or lose weight under the new rule."
   ]
  },
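  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the step concrete, here is the same update worked out for two rows of the table above. It is a sketch that assumes $y$ denotes the true weight and that the residual is $r = y - f$, as defined earlier; the numbers are taken from the $f_1$ and $\\gamma_2$ columns.\n",
    "\n",
    "$$f_2 = f_1 + \\alpha\\gamma_2, \\qquad r_2 = y - f_2$$\n",
    "\n",
    "$$\\text{Alex (1.6 m):}\\quad f_2 = 72.8 + 0.2 \\times 7.288889 \\approx 74.2578, \\qquad r_2 \\approx 88 - 74.2578 = 13.7422$$\n",
    "\n",
    "$$\\text{Eric (1.5 m):}\\quad f_2 = 72.8 + 0.2 \\times (-7.288889) \\approx 71.3422, \\qquad r_2 \\approx 77 - 71.3422 = 5.6578$$\n",
    "\n",
    "Both results match the $f_2$ and $r_2$ columns. Alex's residual shrinks, while Eric's grows slightly because the height rule pushes his prediction away from his true weight."
   ]
  },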
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>height</th>\n",
       "      <th>gender</th>\n",
       "      <th>weight</th>\n",
       "      <th>$f_2$</th>\n",
       "      <th>$r_2$</th>\n",
       "      <th>$\\gamma_3$</th>\n",
       "      <th>$f_3$</th>\n",
       "      <th>$r_3$</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Alex</td>\n",
       "      <td>1.6</td>\n",
       "      <td>male</td>\n",
       "      <td>88</td>\n",
       "      <td>74.257778</td>\n",
       "      <td>13.742222</td>\n",
       "      <td>6.047407</td>\n",
       "      <td>75.467259</td>\n",
       "      <td>12.532741</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brunei</td>\n",
       "      <td>1.6</td>\n",
       "      <td>female</td>\n",
       "      <td>76</td>\n",
       "      <td>70.991111</td>\n",
       "      <td>5.008889</td>\n",
       "      <td>-6.047407</td>\n",
       "      <td>69.781630</td>\n",
       "      <td>6.218370</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Candy</td>\n",
       "      <td>1.5</td>\n",
       "      <td>female</td>\n",
       "      <td>56</td>\n",
       "      <td>68.075556</td>\n",
       "      <td>-12.075556</td>\n",
       "      <td>-6.047407</td>\n",
       "      <td>66.866074</td>\n",
       "      <td>-10.866074</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>David</td>\n",
       "      <td>1.8</td>\n",
       "      <td>male</td>\n",
       "      <td>73</td>\n",
       "      <td>74.257778</td>\n",
       "      <td>-1.257778</td>\n",
       "      <td>6.047407</td>\n",
       "      <td>75.467259</td>\n",
       "      <td>-2.467259</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Eric</td>\n",
       "      <td>1.5</td>\n",
       "      <td>male</td>\n",
       "      <td>77</td>\n",
       "      <td>71.342222</td>\n",
       "      <td>5.657778</td>\n",
       "      <td>6.047407</td>\n",
       "      <td>72.551704</td>\n",
       "      <td>4.448296</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Felicity</td>\n",
       "      <td>1.4</td>\n",
       "      <td>female</td>\n",
       "      <td>57</td>\n",
       "      <td>68.075556</td>\n",
       "      <td>-11.075556</td>\n",
       "      <td>-6.047407</td>\n",
       "      <td>66.866074</td>\n",
       "      <td>-9.866074</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       name  height  gender  weight      $f_2$      $r_2$  $\\gamma_3$  \\\n",
       "0      Alex     1.6    male      88  74.257778  13.742222    6.047407   \n",
       "1    Brunei     1.6  female      76  70.991111   5.008889   -6.047407   \n",
       "2     Candy     1.5  female      56  68.075556 -12.075556   -6.047407   \n",
       "3     David     1.8    male      73  74.257778  -1.257778    6.047407   \n",
       "4      Eric     1.5    male      77  71.342222   5.657778    6.047407   \n",
       "5  Felicity     1.4  female      57  68.075556 -11.075556   -6.047407   \n",
       "\n",
       "       $f_3$      $r_3$  \n",
       "0  75.467259  12.532741  \n",
       "1  69.781630   6.218370  \n",
       "2  66.866074 -10.866074  \n",
       "3  75.467259  -2.467259  \n",
       "4  72.551704   4.448296  \n",
       "5  66.866074  -9.866074  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"238pt\" height=\"165pt\"\r\n",
       " viewBox=\"0.00 0.00 238.00 165.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 161)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-161 234,-161 234,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<path fill=\"#f2c09c\" stroke=\"black\" d=\"M156.5,-157C156.5,-157 74.5,-157 74.5,-157 68.5,-157 62.5,-151 62.5,-145 62.5,-145 62.5,-101 62.5,-101 62.5,-95 68.5,-89 74.5,-89 74.5,-89 156.5,-89 156.5,-89 162.5,-89 168.5,-95 168.5,-101 168.5,-101 168.5,-145 168.5,-145 168.5,-151 162.5,-157 156.5,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gender &lt;= 0.5</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 86.003</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-96.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.0</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<path fill=\"#ffffff\" stroke=\"black\" d=\"M95,-53C95,-53 12,-53 12,-53 6,-53 0,-47 0,-41 0,-41 0,-12 0,-12 0,-6 6,-0 12,-0 12,-0 95,-0 95,-0 101,-0 107,-6 107,-12 107,-12 107,-41 107,-41 107,-47 101,-53 95,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 61.288</text>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = &#45;6.047</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M93.8154,-88.9485C88.0536,-80.1664 81.8185,-70.6629 76.0136,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"78.7975,-59.6779 70.3855,-53.2367 72.9447,-63.5178 78.7975,-59.6779\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"65.3022\" y=\"-74.0145\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<path fill=\"#e58139\" stroke=\"black\" d=\"M218,-53C218,-53 137,-53 137,-53 131,-53 125,-47 125,-41 125,-41 125,-12 125,-12 125,-6 131,-0 137,-0 137,-0 218,-0 218,-0 224,-0 230,-6 230,-12 230,-12 230,-41 230,-41 230,-47 224,-53 218,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 37.576</text>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 6.047</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M137.185,-88.9485C142.946,-80.1664 149.181,-70.6629 154.986,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"158.055,-63.5178 160.615,-53.2367 152.203,-59.6779 158.055,-63.5178\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"165.698\" y=\"-74.0145\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x2858be5acd0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "iterate(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Iteration 2 again tells us that gender matters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>height</th>\n",
       "      <th>gender</th>\n",
       "      <th>weight</th>\n",
       "      <th>$f_3$</th>\n",
       "      <th>$r_3$</th>\n",
       "      <th>$\\gamma_4$</th>\n",
       "      <th>$f_4$</th>\n",
       "      <th>$r_4$</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Alex</td>\n",
       "      <td>1.6</td>\n",
       "      <td>male</td>\n",
       "      <td>88</td>\n",
       "      <td>75.467259</td>\n",
       "      <td>12.532741</td>\n",
       "      <td>5.427951</td>\n",
       "      <td>76.552849</td>\n",
       "      <td>11.447151</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brunei</td>\n",
       "      <td>1.6</td>\n",
       "      <td>female</td>\n",
       "      <td>76</td>\n",
       "      <td>69.781630</td>\n",
       "      <td>6.218370</td>\n",
       "      <td>5.427951</td>\n",
       "      <td>70.867220</td>\n",
       "      <td>5.132780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Candy</td>\n",
       "      <td>1.5</td>\n",
       "      <td>female</td>\n",
       "      <td>56</td>\n",
       "      <td>66.866074</td>\n",
       "      <td>-10.866074</td>\n",
       "      <td>-5.427951</td>\n",
       "      <td>65.780484</td>\n",
       "      <td>-9.780484</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>David</td>\n",
       "      <td>1.8</td>\n",
       "      <td>male</td>\n",
       "      <td>73</td>\n",
       "      <td>75.467259</td>\n",
       "      <td>-2.467259</td>\n",
       "      <td>5.427951</td>\n",
       "      <td>76.552849</td>\n",
       "      <td>-3.552849</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Eric</td>\n",
       "      <td>1.5</td>\n",
       "      <td>male</td>\n",
       "      <td>77</td>\n",
       "      <td>72.551704</td>\n",
       "      <td>4.448296</td>\n",
       "      <td>-5.427951</td>\n",
       "      <td>71.466114</td>\n",
       "      <td>5.533886</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Felicity</td>\n",
       "      <td>1.4</td>\n",
       "      <td>female</td>\n",
       "      <td>57</td>\n",
       "      <td>66.866074</td>\n",
       "      <td>-9.866074</td>\n",
       "      <td>-5.427951</td>\n",
       "      <td>65.780484</td>\n",
       "      <td>-8.780484</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       name  height  gender  weight      $f_3$      $r_3$  $\\gamma_4$  \\\n",
       "0      Alex     1.6    male      88  75.467259  12.532741    5.427951   \n",
       "1    Brunei     1.6  female      76  69.781630   6.218370    5.427951   \n",
       "2     Candy     1.5  female      56  66.866074 -10.866074   -5.427951   \n",
       "3     David     1.8    male      73  75.467259  -2.467259    5.427951   \n",
       "4      Eric     1.5    male      77  72.551704   4.448296   -5.427951   \n",
       "5  Felicity     1.4  female      57  66.866074  -9.866074   -5.427951   \n",
       "\n",
       "       $f_4$      $r_4$  \n",
       "0  76.552849  11.447151  \n",
       "1  70.867220   5.132780  \n",
       "2  65.780484  -9.780484  \n",
       "3  76.552849  -3.552849  \n",
       "4  71.466114   5.533886  \n",
       "5  65.780484  -8.780484  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"238pt\" height=\"165pt\"\r\n",
       " viewBox=\"0.00 0.00 238.00 165.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 161)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-161 234,-161 234,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<path fill=\"#f2c09c\" stroke=\"black\" d=\"M157,-157C157,-157 74,-157 74,-157 68,-157 62,-151 62,-145 62,-145 62,-101 62,-101 62,-95 68,-89 74,-89 74,-89 157,-89 157,-89 163,-89 169,-95 169,-101 169,-101 169,-145 169,-145 169,-151 163,-157 157,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">height &lt;= 1.55</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 72.837</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-96.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.0</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<path fill=\"#ffffff\" stroke=\"black\" d=\"M95,-53C95,-53 12,-53 12,-53 6,-53 0,-47 0,-41 0,-41 0,-12 0,-12 0,-6 6,-0 12,-0 12,-0 95,-0 95,-0 101,-0 107,-6 107,-12 107,-12 107,-41 107,-41 107,-47 101,-53 95,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 48.937</text>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = &#45;5.428</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M93.8154,-88.9485C88.0536,-80.1664 81.8185,-70.6629 76.0136,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"78.7975,-59.6779 70.3855,-53.2367 72.9447,-63.5178 78.7975,-59.6779\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"65.3022\" y=\"-74.0145\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<path fill=\"#e58139\" stroke=\"black\" d=\"M218,-53C218,-53 137,-53 137,-53 131,-53 125,-47 125,-41 125,-41 125,-12 125,-12 125,-6 131,-0 137,-0 137,-0 218,-0 218,-0 224,-0 230,-6 230,-12 230,-12 230,-41 230,-41 230,-47 224,-53 218,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 37.812</text>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 5.428</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M137.185,-88.9485C142.946,-80.1664 149.181,-70.6629 154.986,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"158.055,-63.5178 160.615,-53.2367 152.203,-59.6779 158.055,-63.5178\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"165.698\" y=\"-74.0145\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x2858c369340>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "iterate(3)"
   ]
  },
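  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Iteration 3 splits on height again (height <= 1.55), which suggests that height is important, too. Note that the MSE of the residuals at the root of each tree keeps falling: 105.129 in iteration 1, 86.003 in iteration 2, and 72.837 in iteration 3."
   ]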
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>height</th>\n",
       "      <th>gender</th>\n",
       "      <th>weight</th>\n",
       "      <th>$f_4$</th>\n",
       "      <th>$r_4$</th>\n",
       "      <th>$\\gamma_5$</th>\n",
       "      <th>$f_5$</th>\n",
       "      <th>$r_5$</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Alex</td>\n",
       "      <td>1.6</td>\n",
       "      <td>male</td>\n",
       "      <td>88</td>\n",
       "      <td>76.552849</td>\n",
       "      <td>11.447151</td>\n",
       "      <td>4.476063</td>\n",
       "      <td>77.448062</td>\n",
       "      <td>10.551938</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Brunei</td>\n",
       "      <td>1.6</td>\n",
       "      <td>female</td>\n",
       "      <td>76</td>\n",
       "      <td>70.867220</td>\n",
       "      <td>5.132780</td>\n",
       "      <td>-4.476063</td>\n",
       "      <td>69.972007</td>\n",
       "      <td>6.027993</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Candy</td>\n",
       "      <td>1.5</td>\n",
       "      <td>female</td>\n",
       "      <td>56</td>\n",
       "      <td>65.780484</td>\n",
       "      <td>-9.780484</td>\n",
       "      <td>-4.476063</td>\n",
       "      <td>64.885271</td>\n",
       "      <td>-8.885271</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>David</td>\n",
       "      <td>1.8</td>\n",
       "      <td>male</td>\n",
       "      <td>73</td>\n",
       "      <td>76.552849</td>\n",
       "      <td>-3.552849</td>\n",
       "      <td>4.476063</td>\n",
       "      <td>77.448062</td>\n",
       "      <td>-4.448062</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Eric</td>\n",
       "      <td>1.5</td>\n",
       "      <td>male</td>\n",
       "      <td>77</td>\n",
       "      <td>71.466114</td>\n",
       "      <td>5.533886</td>\n",
       "      <td>4.476063</td>\n",
       "      <td>72.361326</td>\n",
       "      <td>4.638674</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Felicity</td>\n",
       "      <td>1.4</td>\n",
       "      <td>female</td>\n",
       "      <td>57</td>\n",
       "      <td>65.780484</td>\n",
       "      <td>-8.780484</td>\n",
       "      <td>-4.476063</td>\n",
       "      <td>64.885271</td>\n",
       "      <td>-7.885271</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       name  height  gender  weight      $f_4$      $r_4$  $\\gamma_5$  \\\n",
       "0      Alex     1.6    male      88  76.552849  11.447151    4.476063   \n",
       "1    Brunei     1.6  female      76  70.867220   5.132780   -4.476063   \n",
       "2     Candy     1.5  female      56  65.780484  -9.780484   -4.476063   \n",
       "3     David     1.8    male      73  76.552849  -3.552849    4.476063   \n",
       "4      Eric     1.5    male      77  71.466114   5.533886    4.476063   \n",
       "5  Felicity     1.4  female      57  65.780484  -8.780484   -4.476063   \n",
       "\n",
       "       $f_5$      $r_5$  \n",
       "0  77.448062  10.551938  \n",
       "1  69.972007   6.027993  \n",
       "2  64.885271  -8.885271  \n",
       "3  77.448062  -4.448062  \n",
       "4  72.361326   4.638674  \n",
       "5  64.885271  -7.885271  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"238pt\" height=\"165pt\"\r\n",
       " viewBox=\"0.00 0.00 238.00 165.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 161)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-161 234,-161 234,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<path fill=\"#f2c09c\" stroke=\"black\" d=\"M156.5,-157C156.5,-157 74.5,-157 74.5,-157 68.5,-157 62.5,-151 62.5,-145 62.5,-145 62.5,-101 62.5,-101 62.5,-95 68.5,-89 74.5,-89 74.5,-89 156.5,-89 156.5,-89 162.5,-89 168.5,-95 168.5,-101 168.5,-101 168.5,-145 168.5,-145 168.5,-151 162.5,-157 156.5,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gender &lt;= 0.5</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 62.231</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 6</text>\r\n",
       "<text text-anchor=\"middle\" x=\"115.5\" y=\"-96.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.0</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<path fill=\"#ffffff\" stroke=\"black\" d=\"M95,-53C95,-53 12,-53 12,-53 6,-53 0,-47 0,-41 0,-41 0,-12 0,-12 0,-6 6,-0 12,-0 12,-0 95,-0 95,-0 101,-0 107,-6 107,-12 107,-12 107,-41 107,-41 107,-47 101,-53 95,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 46.332</text>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"53.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = &#45;4.476</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M93.8154,-88.9485C88.0536,-80.1664 81.8185,-70.6629 76.0136,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"78.7975,-59.6779 70.3855,-53.2367 72.9447,-63.5178 78.7975,-59.6779\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"65.3022\" y=\"-74.0145\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<path fill=\"#e58139\" stroke=\"black\" d=\"M218,-53C218,-53 137,-53 137,-53 131,-53 125,-47 125,-41 125,-41 125,-12 125,-12 125,-6 131,-0 137,-0 137,-0 218,-0 218,-0 224,-0 230,-6 230,-12 230,-12 230,-41 230,-41 230,-47 224,-53 218,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">mse = 38.059</text>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"177.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 4.476</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M137.185,-88.9485C142.946,-80.1664 149.181,-70.6629 154.986,-61.815\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"158.055,-63.5178 160.615,-53.2367 152.203,-59.6779 158.055,-63.5178\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"165.698\" y=\"-74.0145\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x2858be5acd0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "iterate(4)"
   ]
  },
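  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check on the table above, the update for Alex can be worked out by hand. This is only a sketch: it assumes a learning rate of 0.2, which is the value the numbers in the table imply, and the symbol $\\nu$ for the learning rate is introduced here purely for illustration.\n",
    "\n",
    "$$f_5 = f_4 + \\nu \\gamma_5 = 76.552849 + 0.2 \\times 4.476063 \\approx 77.448062$$\n",
    "\n",
    "$$r_5 = weight - f_5 = 88 - 77.448062 = 10.551938$$\n",
    "\n",
    "The leaf values of the tree ($\\pm 4.476$) are the means of the $r_4$ residuals in each gender group, and the root node's `mse = 62.231` matches the mean of the squared $r_4$ residuals."
   ]
  },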
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's stop at iteration 4 and take a look at the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEWCAYAAACJ0YulAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXhU9b3H8fc3CQkQdgiEPWyihFUCiMjiBmq9grVarCIqirTYau2m3XuvvddWu+itqIgIKG61Kt7aCoiySNkCsu/7Tti3ANm+948Z8qQYIGAmZ5L5vJ6HZzLnzJz5Tvs4nznn95vvz9wdERERgLigCxARkeihUBARkUIKBRERKaRQEBGRQgoFEREppFAQEZFCCgWpkMxss5ldF3QdAGb2pJntM7PdQdcicj4KBZEIMrOmwA+Adu6eepbHVDezP4aD7LiZbTWzd82se5HHeHjfsXDAvGlmtYrsn25mJ4vsf8/MGkb+HUpFo1AQiazmwH53zypup5klAZ8CHYCbgRrAZcBbwE1nPLyTu1cDWgK1gV+fsf/h8P7WQDXgmVJ6DxJDFApS4ZlZkpn92cx2hv/9OfxhjJnVM7O/m9khMztgZrPMLC687ydmtsPMjprZGjO79izHr2lmE8xsr5ltMbOfm1lc+PLVVKBR+Bv8uGKePgRoAgxy9+Xunu/ux939XXf/dXGv5+5HgA+BdmfZfwj4AOh8Qf9DiQAJQRcgUgZ+BlxB6EPSgUnAz4FfELq0sx1ICT/2CsDNrC3wMNDN3XeaWRoQf5bj/y9Qk9A3+LrAFGCXu79iZjcCr7t7k7M89zpgsrsfL+mbMbPawCBg7ln21wW+Dqwv6TFFTtOZgsSCu4D/dPcsd98L/IbQN3SAXKAh0Nzdc919locaguUDSUA7M6vk7pvdfcOZBzazeOCbwBPuftTdNwN/KHL886kHFA5Am1nn8FnLETNbc8ZjF5nZIWAf0Ax46Yz9z5nZ4fD+esB3S1iDSCGFgsSCRsCWIve3hLcBPE3oG/UUM9toZo8DuPt64FFC1+2zzOwtM2vEl9UDEos5fuMS1rafUCgRft3F7l6L0Df9pDMee3l4X2XgBWCWmVUusv977l4T6EhozOFsZyciZ6VQkFiwk9CA72nNwtsIf7v/gbu3BP4DeOz02IG7v+HuV4Wf68Dvijn2PkJnG2cef0cJa5sG9Dez5JK+GXfPBcYALYD2xexfBjwJPG9mVtLjioBCQWLDm8DPzSzFzOoBvwReBzCzm82sdfjD8wihy0b5ZtbWzK4JD0ifBE6E9/0bd88H3gF+G55a2hx47PTxS2ACsAt438zam1l8+Nt/xtmeEL5kdV+4po1nedh4oD5wSwnrEAEUChIbngQygaXAMmBReBtAG+AT4BgwBxjl7tMJXbp5itCZwG5CH7A/PcvxvwscJ/QB/TnwBjC2JIW5+0ngamAl8BGhYFoDdAPuOOPhS8zsGHAQGArc6u4HznLcHOA5QoPpIiVmWmRHRERO05mCiIgUUiiIiEghhYKIiBRSKIiISKFy3eaiXr16npaWFnQZIiLlysKFC/e5e0px+8p1KKSlpZGZmRl0GSIi5YqZbTnbPl0+EhGRQgoFEREppFAQEZFCEQsFMxtrZllmtrzItv8ys6VmttjMphTtOmlmT5jZ+vBiJgMiVZeIiJxdJM8UxgE3nLHtaXfv6O6dgb8TakyGmbUDBgPp4eeMCjf9EhGRMhSxUHD3mcCBM7YdKXI3mVA7YoCBwFvufsrdNxHqb98dEREpU2U+JdXMfgvcAxwm1B0SQguSFF1acDtnWaTEzIYDwwGaNWsWuUJFRGJQmQ80u/vP3L0pMJHQGrgAxS0EUmz7Vncf7e4Z7p6RklLsby/OKzsnj19/uILDJ3Iv6vkiIhVVkLOP3gBuC/+9HWhaZF8TwitjRcLKnUd4Y95W7n11PsdP5UXqZUREyp0y
DQUza1Pk7i3A6vDfHwKDzSzJzFoQWvhkfqTqyEirw3N3dmHp9sM8MD6Tk7lfWlBLRCQmRXJK6puEVrJqa2bbzWwY8JSZLTezpUB/4BEAd19BaEnDlcDHwMjwMocRc0P7VJ65vSNzN+3nOxMXkZNXEMmXExEpF8r1ymsZGRn+VXsfTZy3hZ+9v5yvdWjIc3d2IT5O65yLSMVmZgvdvdh1wMt1Q7zScFeP5mSfyue3/1hFlcR4fn9bR+IUDCISo2I+FAAe7NOSY6fyeHbaOpIT4/n1LemYKRhEJPYoFMIeva4N2Tl5vDxrE8lJCfz4hkuDLklEpMwpFMLMjJ/edBnHc/IZNX0DyUkJjLy6ddBliYiUKYVCEWbGkwPbcyInn6cnr6FqYjz39WoRdFkiImVGoXCGuDjj6W90JDsnj9/830qSExO4o1vT8z9RRKQC0HoKxUiIj+O5O7vQ55IUfvLeUv5vScR+XC0iElUUCmeRlBDPS3d3pVvzOnz/7cV8snJP0CWJiEScQuEcqiTG88q9GaQ3qsF33ljE7PX7gi5JRCSiFArnUb1yJcbd150WdZN5YHwmC7ccOP+TRETKKYVCCdROTuS1B7qTWrMy9766gOU7DgddkohIRCgUSqh+9cq8/kAPalSuxJBX5rFuz9GgSxIRKXUKhQvQuFYVJj7Qg4T4OO4aM48t+48HXZKISKlSKFygtHrJvD6sB7n5BXzr5XnsPHQi6JJEREqNQuEitE2tzoT7e3DkRC53j5nH3qOngi5JRKRUKBQuUocmNRl7Xzd2HT7JkFfmcSg7J+iSRES+MoXCV9AtrQ4v35PBxr3HGfrqAo5pvWcRKeciuRznWDPLMrPlRbY9bWarzWypmb1vZrWK7HvCzNab2RozGxCpukrbVW3q8ZdvdWH5jsMMG7eAEzla71lEyq9InimMA244Y9tUoL27dwTWAk8AmFk7YDCQHn7OKDOLj2Btpap/eip/vKMT8zcfYMTrCzmVp2AQkfIpYqHg7jOBA2dsm+Lup6+xzAWahP8eCLzl7qfcfROwHugeqdoiYWDnxvzPrR2YsXYvj7y5mLz8gqBLEhG5YEGOKdwP/DP8d2NgW5F928PbypXB3Zvxi5vb8fGK3fz43aUUFHjQJYmIXJBA1lMws58BecDE05uKeVixn6hmNhwYDtCsWbOI1PdVDLuqBdmn8vjD1LVUSYznyUHttd6ziJQbZR4KZjYUuBm41t1Pf/BvB4quZNMEKHYRA3cfDYwGyMjIiMqv4g9f05pjOXm8NGMj1ZISePzGSxUMIlIulGkomNkNwE+Avu6eXWTXh8AbZvZHoBHQBphflrWVJjPj8RsuJftUPi/N3EhyUgLfu7ZN0GWJiJxXxELBzN4E+gH1zGw78CtCs42SgKnhb85z3X2Eu68ws3eAlYQuK41093I9hcfM+M0t6RzPyeOPU9dSNTGeB3q3DLosEZFzilgouPudxWx+5RyP/y3w20jVE4S4OOP3t3XkRE4+T360iuSkBO7sHn3jICIip+kXzRGWEB/Hs4O70K9tCj99fxmTFu8IuiQRkbNSKJSBxIQ4Xry7Kz1a1OGxd5YwecXuoEsSESmWQqGMVK4Uz5ih3ejQuCbffeMLZq3bG3RJIiJfolAoQ9WSEhh/X3da1a/GgxMymb9J6z2LSHRRKJSxmlUr8dqw7jSqVYX7xy1g6fZDQZckIlJIoRCAetWSmPhAD2pVrcQ9Y+ezZrfWexaR6KBQCEjDmqH1npMSQus9b9qn9Z5FJHgKhQA1r5vMxAd6UODOXS/PZYfWexaRgCkUAta6fnUm3N+do6fyuOvluWQdORl0SSISwxQKUaB945qMu68bWUdPcfcr8zh4XOs9i0gwFApRomvzOoy5J4PN+7MZ+up8jp7MDbokEYlBCoUocmXrerxw1+Ws3HmE+8ctIDsn7/xPEhEpRQqFKHPtZQ348+DOLNxykIde03rPIlK2FApR6OaOjXjqto7MWreP
h9/4glyt9ywiZUShEKXuyGjKr/+jHVNX7uGHf11CvtZ7FpEyEMgazVIy9/ZqwfGcfJ6evIaqifH8960dtKyniESUQiHKjby6NcdP5TFq+gaqJibw869dpmAQkYhRKJQDPxrQluycfF75fBPJSQk8dv0lQZckIhVUxMYUzGysmWWZ2fIi2243sxVmVmBmGWc8/gkzW29ma8xsQKTqKo/MjF/e3I7buzbhuWnrGD1zQ9AliUgFFcmB5nHADWdsWw58HZhZdKOZtQMGA+nh54wys/gI1lbuxMUZT93Wka91bMh//2M1r83dEnRJIlIBRezykbvPNLO0M7atAoq7Jj4QeMvdTwGbzGw90B2YE6n6yqP4OONPd3TmZE4+v/hgOcmJ8Xz98iZBlyUiFUi0TEltDGwrcn97eNuXmNlwM8s0s8y9e2NvScvEhDiev+tyrmxVlx/+dQn/XLYr6JJEpAKJllAobjpNsRPz3X20u2e4e0ZKSkqEy4pOlSvF8/I9GXRuWovvvfUFn63JCrokEakgoiUUtgNNi9xvAuwMqJZyITkpgVfv684lDaoz4rWFzN24P+iSRKQCiJZQ+BAYbGZJZtYCaAPMD7imqFezSiUm3N+dpnWqMmzcAhZv03rPIvLVRHJK6puEBorbmtl2MxtmZrea2XagJ/CRmU0GcPcVwDvASuBjYKS7qxNcCdStlsTrw3pQt1oSQ8fOZ9WuI0GXJCLlmLmX3546GRkZnpmZGXQZUWHbgWxuf3EOeQUFvP1QT1qlVAu6JBGJUma20N0zitsXLZeP5CtqWqcqrz/QA3e4e8w8th3IDrokESmHFAoVSOv61XhtWA+On8rjrjHz2KP1nkXkAikUKph2jWow/v7u7D92irvHzGP/sVNBlyQi5YhCoQLq0qw2Y4Z2Y+uBbO4ZO5/DJ7Tes4iUjEKhgurZqi4v3t2VtXuOar1nESkxhUIFdvWl9XlucBe+2HqQBydkcjJXs3xF5NwUChXcjR0a8vQ3OjF7/X5GTlzEqTwFg4icnUIhBtzWtQn/Nag901Zn8a2X57FPg88ichYKhRgx5Irm/OVbXVi+4zAD/zKb1bv1y2cR+TKFQgy5uWMj3nmoJ7n5Bdw26l98unpP0CWJSJRRKMSYTk1rMenhXqTVS2bY+EzGzNpIeW51IiKlS6EQgxrWrMJfR/RkQLtUnvxoFU+8t4ycvIKgyxKRKKBQiFFVExMYddfljLy6FW8t2MY9Y+dx8HhO0GWJSMAUCjEsLs740YBL+dM3O7FoyyEGjZrN+qxjQZclIgFSKAi3dmnCm8Ov4PipPG4dNZtZ62Jv7WsRCVEoCABdm9fmg5G9aFyrCve+uoAJczYHXZKIBEChIIWa1K7Ku9++kn6XpPDLSSv45aTl5OVrAFokligU5N9US0pg9D0ZDO/TkglztnDfuAXqsioSQyK5RvNYM8sys+VFttUxs6lmti58W7vIvifMbL2ZrTGzAZGqS84vPs746U2X8fvbOjJ3436+Pmo2m/cdD7osESkDkTxTGAfccMa2x4Fp7t4GmBa+j5m1AwYD6eHnjDKz+AjWJiVwR7emvDasB/uP5zBo1GzmbNgfdEkiEmERCwV3nwkcOGPzQGB8+O/xwKAi299y91PuvglYD3SPVG1Scle0rMukkb2om5zIkFfm8faCrUGXJCIRVNZjCg3cfRdA+LZ+eHtjYFuRx20Pb/sSMxtuZplmlrl3r6ZOloXmdZN5f2Qveraqy0/+town/76S/AK1xhCpiKJloNmK2Vbsp467j3b3DHfPSElJiXBZclqNypV49d5u3HtlGmM+38SDEzI5elID0CIVTVmHwh4zawgQvs0Kb98ONC3yuCbAzjKuTc4jIT6OX9+Szn8Nas+MtXv5xgtz2HYgO+iyRKQUlXUofAgMDf89FJhUZPtgM0sysxZAG2B+GdcmJTTkiuaMv687uw6fYNDzs8ncfObQkYiUV5GckvomMAdoa2bbzWwY8BRw
vZmtA64P38fdVwDvACuBj4GR7q51I6PYVW3q8f7IXlSvnMC3Xp7He4u2B12SiJQCK8+99DMyMjwzMzPoMmLaoewcvv36IuZs3M93+rXih/3bEhdX3BCRiEQLM1vo7hnF7YuWgWYpp2pVTWTCsO7c2b0po6Zv4NsTF5Kdkxd0WSJykRQK8pVVio/jv2/twC9ubsfUlXu4/cU57Dp8IuiyROQiKBSkVJgZw65qwStDu7Flfza3/GU2S7YdCrosEblACgUpVVdfWp+/fftKkhLiuOOlOfx9qWYWi5QnCgUpdW1TqzNpZC86NK7Jw298wZ8/WUt5ntAgEksUChIRdaslMfHBHnz98sb8+ZN1fO+txZzM1SxjkWiXEHQBUnElJcTzh9s70aZ+dX4/eTVbD2Tz8pCu1K9ROejSROQsdKYgEWVmfLtfK168uytrdx9l4POzWb7jcNBlichZKBSkTAxIT+Xdb/cE4PYX5zB5xe6AKxKR4pQoFMzsETOrYSGvmNkiM+sf6eKkYklvVJNJI3txSWp1Rry+kBemb9AAtEiUKemZwv3ufgToD6QA9xHuWyRyIerXqMzbw6/g5o6N+N3Hq/nBX5dwKk8D0CLRoqQDzaeb2dwEvOruS8xMDW7kolSuFM9zgzvTOqUaf/pkLVv3Z/PSkK7UrZYUdGkiMa+kZwoLzWwKoVCYbGbVgYLIlSUVnZnxyHVt+Mu3urBsx2EGPj+bNbuPBl2WSMwraSgMAx4Hurl7NlCJ0CUkka/k5o6NeOehnuTkFXDbC//is9VZ53+SiERMSUOhJ7DG3Q+Z2d3AzwHNK5RS0alpLSY93IvmdasybPwCxszaqAFokYCUNBReALLNrBPwY2ALMCFiVUnMaVizCn8d0ZP+7VJ58qNV/PT9ZeTk6QqlSFkraSjkeeir20DgWXd/FqgeubIkFlVNTGDUXZcz8upWvDl/G/eMnceh7JygyxKJKSUNhaNm9gQwBPjIzOIJjStclPDvHpab2QozezS8rY6ZTTWzdeHb2hd7fCm/4uKMHw24lD99sxOLthxi0POz2bD3WNBlicSMkobCN4FThH6vsBtoDDx9MS9oZu2BB4HuQCfgZjNrQ2gge5q7twGmhe9LjLq1SxPeHN6DoyfzGPT8bD5fty/okkRiQolCIRwEE4GaZnYzcNLdL3ZM4TJgrrtnu3seMAO4ldClqfHhx4wHBl3k8aWC6Nq8DpMe7kWjmlUY+up8Xpu7JeiSRCq8kra5uAOYD9wO3AHMM7NvXORrLgf6mFldM6tK6LcPTYEG7r4LIHxb/yKPLxVIk9pV+dt3rqTfJSn84oPl/GrScvLyNQAtEikl/UXzzwj9RiELwMxSgE+Ady/0Bd19lZn9DpgKHAOWACVe6d3MhgPDAZo1a3ahLy/lULWkBEbfk8FT/1zFy7M2sXHfcf7yrcupWeWih7VE5CxKOqYQdzoQwvZfwHO/xN1fcffL3b0PcABYB+wxs4YA4dtif8Xk7qPdPcPdM1JSUi62BCln4uOMn32tHb+7rQNzNuzn66Nms2X/8aDLEqlwSvrB/rGZTTaze83sXuAj4B8X+6JmVj982wz4OvAm8CEwNPyQocCkiz2+VFzf7NaM1x/owf7jOQx8fjZzN+4PuiSRCsVK+stRM7sN6EWoOd5Md3//ol/UbBZQF8gFHnP3aWZWF3gHaAZsBW539wPnOk5GRoZnZmZebBlSjm3ed5xh4xew9UA2vx3UgTu6NQ26JJFyw8wWuntGsfvKczsBhUJsO3wil4ffWMSsdfsY3qclP7nhUuLj1LxX5HzOFQrnvHxkZkfN7Egx/46a2ZHIlCtSMjWrVOLVe7sxtGdzRs/cyEOvZXLsVInnLIhIMc4ZCu5e3d1rFPOvurvXKKsiRc4mIT6O3wxsz38NTOezNXv5xgv/YvvB7KDLEim3tEazVAhDeqYx7r5u7Dh0gkHPz2bhloNBlyRSLikUpMLo3SaF97/Ti+SkBO4cPZcP
vtgRdEki5Y5CQSqU1vWr8cF3enF581o8+vZifvfxaq0BLXIBFApS4dROTmTC/T24s3szXpi+gf5/msmUFbu1cI9ICSgUpEJKTIjjf77egQn3dycxPo7hry1kyCvztQ60yHkoFKRC63NJCv94pDe//o92LNtxmJuem8UvJy3n4HEt3iNSHIWCVHiV4uO4t1cLpv+wH3f1aMbrc7fQ75npjP/XZnVcFTmDQkFiRu3kRP5zYHv+8Uhv2jeuwa8+XMFNz81i1rq9QZcmEjUUChJzLk2twevDejB6SFdO5hYw5JX5PDA+k8371HVVRKEgMcnM6J+eytTH+vD4jZcyZ8M+rv/TDP7nH6s4ejI36PJEAqNQkJiWlBDPiL6t+OxH/RjUuTGjZ23k6mem8/aCreQXaAqrxB6FgghQv3plnr69E5NG9qJ53WR+8rdlDHz+cxZsPmf3dpEKR6EgUkTHJrV4d0RPnh3cmf3Hcrj9xTk8/MYidhw6EXRpImVCoSByBjNjYOfGTPtBXx65tg1TV+7hmmem88epa8nOUWtuqdgUCiJnUTUxge9ffwmf/rAf/dNTeW7aOq79wwwmLd6hlhlSYSkURM6jca0q/O+dXXjnoZ7UrZbII28t5hsvzmHp9kNBlyZS6gIJBTP7vpmtMLPlZvammVU2szpmNtXM1oVvawdRm8jZdG9Rh0kjr+L3t3Vky/7jDHx+Nj/66xKyjp4MujSRUlPmoWBmjYHvARnu3h6IBwYDjwPT3L0NMC18XySqxMcZd3Rrymc/7Mfw3i35YPEOrn56Oi9M36AW3VIhBHX5KAGoYmYJQFVgJzAQGB/ePx4YFFBtIudVvXIlnrjpMqZ8vy89W9Xjdx+vpv+fZjJZLbqlnCvzUHD3HcAzwFZgF3DY3acADdx9V/gxu4D6xT3fzIabWaaZZe7dq541EqwW9ZIZMzSD14aFWnQ/9NpC7n5lnlp0S7kVxOWj2oTOCloAjYBkM7u7pM9399HunuHuGSkpKZEqU+SC9G6Twj8f6c1vbkln+Y4j3PjsTH7xgVp0S/kTxOWj64BN7r7X3XOB94ArgT1m1hAgfJsVQG0iFy0hPo6hV6Yx/Yf9GHJFc96Yv5V+z0zn1dmbyFWLbiknggiFrcAVZlbVzAy4FlgFfAgMDT9mKDApgNpEvrLayYn8ZmB7/vG93nRoXJPf/N9Kbnx2FjPW6nKnRL8gxhTmAe8Ci4Bl4RpGA08B15vZOuD68H2RcqttanVeG9adl+/JIDe/gKFj5zNs3AI2qUW3RDErzzMlMjIyPDMzM+gyRM7rVF4+r87ezP9OW0dOfgH39WrBw9e0pkblSkGXJjHIzBa6e0Zx+/SLZpEyULRF961dGvPyrI1c88x03pqvFt0SXRQKImWofvXK/P4bnfhw5FWk1U3m8feWcctfPmf+JrXoluigUBAJQIcmNfnriJ48d2cXDhzP4Y6X5jDyjUVsP5gddGkS4xQKIgExM27p1IhPf9CPR69rw7RVe7j2DzP445Q1atEtgVEoiASsSmI8j153CdN+EG7R/el6rnlmBh98oRbdUvYUCiJR4nSL7r+O6Em96ok8+vZibnvhXyzZphbdUnYUCiJRpltaHT4Mt+jeeuAEA5+fzQ/eWULWEbXolshTKIhEobjCFt19eahvSz5csoOrn5nOqOnrOZmrFt0SOQoFkShWvXIlnrjxMqZ+vy9Xtq7H7z9eQ/8/zeTj5WrRLZGhUBApB9LqJfPyPaEW3UkJcYx4fSF3jZnH6t1Hgi5NKhiFgkg5UrRF94qdR7jp2Vn8/INlHFCLbiklCgWRcubMFt1vzt9Gv6c/Y+znatEtX51CQaScOt2i+5+P9KZT01r8599XcsOfZ/Lp6j0ab5CLplAQKecuaVCdCfeHWnTnFzj3j8vkpuc+Z9LiHeTpzEEukFpni1QgOXkFfLB4By/N2MCGvcdpUrsKD/ZuyR0ZTamSGB90eRIlztU6W6EgUgEVFDifrNrD
izM2sGjrIeokJzK0Zxr39GxO7eTEoMuTgCkURGKUu7Ng80FenLGBT1dnUaVSPIO7N+WB3i1pXKtK0OVJQM4VCgllXYyIlB0zo3uLOnRvUYfVu48wesZGXpuzhdfmbOGWTo14qG8r2qZWD7pMiSJlfqZgZm2Bt4tsagn8EpgQ3p4GbAbucPeD5zqWzhRELtyOQycYM2sjb83fxoncfK65tD4j+raiW1ptzCzo8qQMRO3lIzOLB3YAPYCRwAF3f8rMHgdqu/tPzvV8hYLIxTt4PIcJc7Ywfs5mDhzP4fJmtRjRtxXXXdaAuDiFQ0UWzaHQH/iVu/cyszVAP3ffZWYNgenu3vZcz1coiHx1J3LyeSdzGy/P2sj2gydoXb8aw/u0ZFDnxiQmaNZ6RRTNoTAWWOTufzGzQ+5eq8i+g+5eu5jnDAeGAzRr1qzrli1byq5gkQosL7+Aj5bt4sUZG1m16wipNSoz7KoW3NmjGdWSNPxYkURlKJhZIrATSHf3PSUNhaJ0piBS+tydGWv38tKMjczZuJ8alRMY0rM5917ZgpTqSUGXJ6UgWmcf3UjoLGFP+P4eM2tY5PJRVoC1icQsM6Nf2/r0a1ufxdsO8dKMDYyavoGXZ23iG12bMLx3S9LqJQddpkRIkBcM7wTeLHL/Q2Bo+O+hwKQyr0hE/k3nprV44e6uTHusL7dd3ph3M7dzzR+mM/KNRSzfcTjo8iQCArl8ZGZVgW1AS3c/HN5WF3gHaAZsBW539wPnOo4uH4mUrawjJxk7ezMT527h6Kk8rmpdjxF9W9GrdV1NZy1HonJMoTQoFESCceRkLm/M28rYzzeRdfQU7RvXYETfVtzYviHxms4a9RQKIhIRp/LyeX/RDkbP3MjGfcdpVqcqD/Zpye1dm1C5khrwRSuFgohEVH6BM3Xlbl6YsZEl2w5Rr1oi916ZxpAr0qhZtVLQ5ckZFAoiUibcnbkbD/DijA3MWLuX5MR47uzejGG9W9CwphrwRQuFgoiUuZU7jzB65gb+b+ku4gwGdm7MiL4taV1fDfiCplAQkcBsO5DNK59v4q0FWzmZW8B1lzXg2/1a0rV5naBLi1kKBREJ3IHjOYz/12bGz9nMoexcuqXVZkTfVlzdtr4a8GpNxuYAAAoQSURBVJUxhYKIRI3snDzeXrCNMbM2sePQCS5pUI2H+rTils6NqBSvBnxlQaEgIlEnN7+Avy/dyUszNrJ691Ea1azMsN4tGdytKclqwBdRCgURiVruzvQ1e3lhxgbmbzpAzSqVGNqzOUOvTKNuNTXgiwSFgoiUC4u2HuTF6RuYumoPSQlx3JHRlAd7t6RpnapBl1ahKBREpFxZn3WM0TM38P4XOyhw+FqHhjzUtyXpjWoGXVqFoFAQkXJpz5GTjP18ExPnbeXYqTz6XJLCiD4t6dlKDfi+CoWCiJRrh0/kMnHeFsZ+vpl9x07RqUlNRvRtRf/0VDXguwgKBRGpEE7m5vPeoh2MnrmBzfuzaVEvmeF9WnJrl8ZqwHcBFAoiUqHkFziTV+zmxRkbWLr9MNUrJ3DNpfUZkJ5K30tSNKX1PKJ1OU4RkYsSH2fc1KEhN7ZPZc7G/XzwxQ4+WZXFpMU7SUyIo3fregxIT+W6dg2ok5wYdLnlikJBRMotM+PKVvW4slU98vILyNxykMkrdjNlxR6mrc4i7j3ollaHAemp9E9vQJPamtp6PkEtx1kLGAO0Bxy4H1gDvA2kAZuBO9z94LmOo8tHIlIcd2fFziNMWbGbySv2sGbPUQDaN67BgHapDGifSpv61WJ2BlPUjSmY2XhglruPMbNEoCrwU+CAuz9lZo8Dtd39J+c6jkJBREpi077j4YDYzaKthwBoUS+Z/ukN6N8ulS5Na8VUU76oCgUzqwEsAVp6kRc3szVAP3ffZWYNgenu3vZcx1IoiMiFyjpykikr9zB5xW7mbNhPXoFTv3oS17drwID0VK5oWZfEhIrdmC/aQqEzMBpYCXQC
FgKPADvcvVaRxx1099rnOpZCQUS+isMncvlsdRaTV+xm+pq9nMjNp3rlBK4Nz2TqU0FnMkVbKGQAc4Fe7j7PzJ4FjgDfLUkomNlwYDhAs2bNum7ZsqWMKheRiuxkbj6fr9vH5BW7+WTVHg5m55KUEEfvNvXon57KdZdVnJlM0RYKqcBcd08L3+8NPA60RpePRCQK5OUXsGDz6ZlMu9l5+CRxBt1bnJ7JlErjWuV3zemoCgUAM5sFPODua8zs10ByeNf+IgPNddz9x+c6jkJBRCLN3Vm+4wiTwwPV67KOAdChcU0GpDegf3r5m8kUjaHQmdCU1ERgI3AfEAe8AzQDtgK3u/uBcx1HoSAiZW3j3mOFA9VfnDGTaUB6Kp2bRP9MpqgLhdKiUBCRIO0Jz2SaUmQmU4Ma/z6TKRqXGFUoiIhE2OHsXD5ds4fJy/cwY21oJlONyglce1kDBqQ3oM8lKVRNjI6ZTAoFEZEydDI3n5lr9zJ5xR6mrd7DocKZTCkMSG/AdZc1oHaAM5nUEE9EpAxVrhRP//Aspbz8AuZvPsCUFXsKp7vGxxnd0+oUDlQ3iqKZTDpTEBEpI+7Osh2HwzOZ9rD+jJlMA9JTaV0GM5l0+UhEJApt2HusMCCWbAvNZGpZL5n+6akMSG9ApwjNZFIoiIhEud2HTzJlZei3EHM3HiA/PJOpf7tUBqSn0qNlnVKbyaRQEBEpRw5l5zBtVRZTVu5mxtq9nMwtKNWZTAoFEZFy6kROPjPX7WXyit1MW5XF4RO5VK4Ux909mvPzm9td1DE1+0hEpJyqkhjPgPTQJaTc/ALmbzrA5BW7IzZjSaEgIlJOVIqPo1frevRqXS9irxF9v78WEZHAKBRERKSQQkFERAopFEREpJBCQURECikURESkkEJBREQKKRRERKRQuW5zYWZ7gS1f4RD1gH2lVE55EGvvF/SeY4Xe84Vp7u4pxe0o16HwVZlZ5tn6f1REsfZ+Qe85Vug9lx5dPhIRkUIKBRERKRTroTA66ALKWKy9X9B7jhV6z6UkpscURETk38X6mYKIiBShUBARkUIxGQpmdoOZrTGz9Wb2eND1RJqZjTWzLDNbHnQtZcXMmprZZ2a2ysxWmNkjQdcUaWZW2czmm9mS8Hv+TdA1lQUzizezL8zs70HXUlbMbLOZLTOzxWZWqmsSx9yYgpnFA2uB64HtwALgTndfGWhhEWRmfYBjwAR3bx90PWXBzBoCDd19kZlVBxYCgyr4/88GJLv7MTOrBHwOPOLucwMuLaLM7DEgA6jh7jcHXU9ZMLPNQIa7l/oP9mLxTKE7sN7dN7p7DvAWMDDgmiLK3WcCB4Kuoyy5+y53XxT++yiwCmgcbFWR5SHHwncrhf9V6G99ZtYE+BowJuhaKopYDIXGwLYi97dTwT8sYp2ZpQFdgHnBVhJ54Uspi4EsYKq7V/T3/Gfgx0BB0IWUMQemmNlCMxtemgeOxVCwYrZV6G9TsczMqgF/Ax519yNB1xNp7p7v7p2BJkB3M6uwlwvN7GYgy90XBl1LAHq5++XAjcDI8CXiUhGLobAdaFrkfhNgZ0C1SASFr6v/DZjo7u8FXU9ZcvdDwHTghoBLiaRewC3h6+tvAdeY2evBllQ23H1n+DYLeJ/QZfFSEYuhsABoY2YtzCwRGAx8GHBNUsrCg66vAKvc/Y9B11MWzCzFzGqF/64CXAesDraqyHH3J9y9ibunEfrv+FN3vzvgsiLOzJLDkycws2SgP1BqMwtjLhTcPQ94GJhMaPDxHXdfEWxVkWVmbwJzgLZmtt3MhgVdUxnoBQwh9O1xcfjfTUEXFWENgc/MbCmhLz9T3T1mpmnGkAbA52a2BJgPfOTuH5fWwWNuSqqIiJxdzJ0piIjI2SkURESkkEJBREQKKRRERKSQQkFERAopFETCzOxf4ds0M/tWKR/7pxdSg0hQNCVV5Axm1g/44YV03DSzeHfPP8f+Y+5erTTqE4kknSmIhJnZ6Q6j
TwG9wz94+364ydzTZrbAzJaa2UPhx/cLr9nwBrAsvO2DcJOyFacblZnZU0CV8PEmhrc9ZmbLw/8ePbOG8LGnm9m7ZrbazCaGf6UtElEJQRcgEoUep8iZQvjD/bC7dzOzJGC2mU0JP7Y70N7dN4Xv3+/uB8JtJhaY2d/c/XEzezjcqA4z6wrcB/Qg1KBxnpnNcPcvzqijC5BOqDfXbEK/0v48Yu9aBJ0piJREf+CecEvqeUBdoE143/wigQDwvXD7gbmEGi+24cuuAt539+Ph9Q/eA3oX87j57r7d3QuAxUBaqbwbkXPQmYLI+RnwXXef/G8bQ2MPx8+4fx3Q092zzWw6UPksxyuJU0X+zkf/vUoZ0JmCyJcdBaoXuT8Z+Ha4FTdmdkm4O+WZagIHw4FwKXBFkX25p58PzAQGmVnV8HFuBWaV+rsQuQj65iHyZUuBvPBloHHAs4Qu3SwKD/buBQYV87yPgRHhLqVrCF1COm00sNTMFrn7XWY2jlCHS4AxxYwniARCU1JFRKSQLh+JiEghhYKIiBRSKIiISCGFgoiIFFIoiIhIIYWCiIgUUiiIiEih/weew3W3WW8Q7wAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(loss)\n",
    "plt.title('loss of GBR')\n",
    "plt.xlabel('iteration')\n",
    "plt.ylabel('loss')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hopefully this walkthrough makes the idea clear.\n",
    "\n",
    "# References\n",
    "\n",
    "https://en.wikipedia.org/wiki/Gradient_boosting\n",
    "\n",
    "https://www.youtube.com/watch?v=3CC4N4z3GJc&list=PLblh5JKOoLUICTaGLRoHQDuF_7q2GfuJF&index=44\n",
    "\n",
    "https://www.youtube.com/watch?v=2xudPOBz-vs&list=PLblh5JKOoLUICTaGLRoHQDuF_7q2GfuJF&index=45\n",
    "\n",
    "This article is based on the videos linked above."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
